# Gate for this directory: stop processing this file unless CUDA support is
# enabled, or the model-optimize tool / Python build is requested.
if(NOT LITE_WITH_CUDA AND NOT LITE_ON_MODEL_OPTIMIZE_TOOL AND NOT LITE_WITH_PYTHON)
  return()
endif()

if(LITE_WITH_CUDA)
  # CUDA build: record that the kernels are not faked and extend the
  # lite_kernel_deps cache entry with the CUDA helper targets listed here,
  # followed by the contents of ${math_cuda} and ${cuda_kernels}.
  set(IS_FAKED_KERNEL false CACHE INTERNAL "")
  set(lite_kernel_deps
      ${lite_kernel_deps}
      cuda_gemm
      cuda_scale
      cuda_elementwise
      cuda_transpose
      cudnn_pool
      cuda_batched_gemm
      ${math_cuda}
      ${cuda_kernels}
      CACHE INTERNAL "")
else()
  # Reached only when LITE_ON_MODEL_OPTIMIZE_TOOL or LITE_WITH_PYTHON is set
  # without LITE_WITH_CUDA: the kernels are flagged as faked and
  # lite_kernel_deps is left untouched.
  # NOTE(review): how IS_FAKED_KERNEL is consumed is not visible in this file;
  # presumably add_kernel reads it -- confirm against its definition.
  set(IS_FAKED_KERNEL true CACHE INTERNAL "")
endif()

# NOTE(review): this status line is also emitted on the faked-kernel path.
message(STATUS "compile with lite CUDA kernels")

# basic kernels
# Each call passes: kernel name, the CUDA target keyword, a level keyword
# (basic/extra) and the single source file implementing the kernel.
# NOTE(review): add_kernel is defined elsewhere; the exact meaning of the level
# keyword should be confirmed there.
add_kernel(mul_compute_cuda CUDA basic
           SRCS mul_compute.cc)
add_kernel(fc_compute_cuda CUDA basic
           SRCS fc_compute.cu)
add_kernel(gru_compute_cuda CUDA basic
           SRCS gru_compute.cu)
add_kernel(matmul_compute_cuda CUDA basic
           SRCS matmul_compute.cc)
add_kernel(search_group_padding_compute_cuda CUDA basic
           SRCS search_group_padding_compute.cu)
add_kernel(io_copy_compute_cuda CUDA basic
           SRCS io_copy_compute.cc)
add_kernel(leaky_relu_compute_cuda CUDA basic
           SRCS leaky_relu_compute.cu)
add_kernel(abs_compute_cuda CUDA basic
           SRCS abs_compute.cu)
add_kernel(tanh_compute_cuda CUDA basic
           SRCS tanh_compute.cu)
add_kernel(relu_compute_cuda CUDA basic
           SRCS relu_compute.cu)
add_kernel(sigmoid_compute_cuda CUDA basic
           SRCS sigmoid_compute.cu)
add_kernel(yolo_box_compute_cuda CUDA basic
           SRCS yolo_box_compute.cu)
# NOTE(review): the next two kernels are registered with the `extra` level even
# though they sit in the basic section; their position is kept as-is.
add_kernel(sequence_pool_compute_cuda CUDA extra
           SRCS sequence_pool_compute.cu)
add_kernel(sequence_pool_concat_compute_cuda CUDA extra
           SRCS sequence_pool_concat_compute.cu)
add_kernel(transpose_compute_cuda CUDA basic
           SRCS transpose_compute.cu)
add_kernel(nearest_interp_compute_cuda CUDA basic
           SRCS nearest_interp_compute.cu)
# conv2d is the one kernel here whose name is not derived from its source file.
add_kernel(conv2d_cuda CUDA basic
           SRCS conv_compute.cc)
add_kernel(concat_compute_cuda CUDA basic
           SRCS concat_compute.cu)
add_kernel(elementwise_compute_cuda CUDA basic
           SRCS elementwise_compute.cu)
add_kernel(calib_compute_cuda CUDA basic
           SRCS calib_compute.cu)
add_kernel(layout_compute_cuda CUDA basic
           SRCS layout_compute.cc)
add_kernel(feed_compute_cuda CUDA basic
           SRCS feed_compute.cc)
add_kernel(fetch_compute_cuda CUDA basic
           SRCS fetch_compute.cc)
add_kernel(scale_compute_cuda CUDA basic
           SRCS scale_compute.cc)
add_kernel(dropout_compute_cuda CUDA basic
           SRCS dropout_compute.cc)
add_kernel(softmax_compute_cuda CUDA basic
           SRCS softmax_compute.cu)
add_kernel(pool_compute_cuda CUDA basic
           SRCS pool_compute.cu)
add_kernel(bilinear_interp_compute_cuda CUDA basic
           SRCS bilinear_interp_compute.cu)

# extra kernels
# Same call shape as the basic section, but every kernel here is registered
# with the `extra` level keyword.
add_kernel(search_seq_depadding_compute_cuda CUDA extra
           SRCS search_seq_depadding_compute.cu)
add_kernel(search_grnn_compute_cuda CUDA extra
           SRCS search_grnn_compute.cu)
add_kernel(sequence_reverse_compute_cuda CUDA extra
           SRCS sequence_reverse_compute.cu)
add_kernel(sequence_reverse_embedding_compute_cuda CUDA extra
           SRCS sequence_reverse_embedding_compute.cu)
add_kernel(sequence_pad_compute_cuda CUDA extra
           SRCS sequence_pad_compute.cu)
add_kernel(sequence_unpad_compute_cuda CUDA extra
           SRCS sequence_unpad_compute.cu)
add_kernel(sequence_concat_compute_cuda CUDA extra
           SRCS sequence_concat_compute.cu)
add_kernel(sequence_mask_compute_cuda CUDA extra
           SRCS sequence_mask_compute.cu)
add_kernel(sequence_arithmetic_compute_cuda CUDA extra
           SRCS sequence_arithmetic_compute.cu)
add_kernel(lookup_table_compute_cuda CUDA extra
           SRCS lookup_table_compute.cu)
add_kernel(attention_padding_mask_compute_cuda CUDA extra
           SRCS attention_padding_mask_compute.cu)
add_kernel(search_fc_compute_cuda CUDA extra
           SRCS search_fc_compute.cu)
add_kernel(sequence_topk_avg_pooling_compute_cuda CUDA extra
           SRCS sequence_topk_avg_pooling_compute.cu)
add_kernel(match_matrix_tensor_compute_cuda CUDA extra
           SRCS match_matrix_tensor_compute.cu)
add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra
           SRCS search_aligned_mat_mul_compute.cc)
add_kernel(search_seq_fc_compute_cuda CUDA extra
           SRCS search_seq_fc_compute.cu)
add_kernel(var_conv_2d_compute_cuda CUDA extra
           SRCS var_conv_2d_compute.cu)
add_kernel(topk_pooling_compute_cuda CUDA extra
           SRCS topk_pooling_compute.cu)
add_kernel(assign_value_compute_cuda CUDA extra
           SRCS assign_value_compute.cu)


# unit test
# The calib test is the only one in this file registered through lite_cc_test,
# and its source file carries a `_cuda_test` suffix.
lite_cc_test(calib_compute_cuda_test
  SRCS calib_compute_cuda_test.cc
  DEPS kernels)

# conv2d does not follow the <stem>_compute_cuda_test naming used below.
nv_test(conv2d_cuda_test
  SRCS conv_compute_test.cc
  DEPS kernels)

# Every remaining test is named <stem>_compute_cuda_test, is built from
# <stem>_compute_test.cc and depends on `kernels`; the stems are listed in
# registration order.
foreach(cuda_test_stem
    nearest_interp
    leaky_relu
    abs
    tanh
    relu
    sigmoid
    yolo_box
    transpose
    search_group_padding
    concat
    elementwise
    softmax
    mul
    fc
    gru
    matmul
    dropout
    bilinear_interp)
  nv_test(${cuda_test_stem}_compute_cuda_test
    SRCS ${cuda_test_stem}_compute_test.cc
    DEPS kernels)
endforeach()

# Currently disabled tests (the reason is not recorded in this file):
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS kernels)
#nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS kernels)

# Tests that are only registered when LITE_BUILD_EXTRA is enabled.
if(LITE_BUILD_EXTRA)
  # Each stem yields <stem>_compute_cuda_test built from <stem>_compute_test.cc
  # with a dependency on `kernels`; stems are listed in registration order.
  foreach(cuda_extra_test_stem
      search_seq_depadding
      search_grnn
      sequence_pool
      lookup_table
      search_aligned_mat_mul
      search_seq_fc
      sequence_reverse
      sequence_pad
      sequence_unpad
      sequence_mask
      var_conv_2d
      sequence_arithmetic
      topk_pooling
      assign_value)
    nv_test(${cuda_extra_test_stem}_compute_cuda_test
      SRCS ${cuda_extra_test_stem}_compute_test.cc
      DEPS kernels)
  endforeach()

  # Currently disabled tests (the reason is not recorded in this file):
  #nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS kernels)
  #nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS kernels)
  #nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS kernels)
  #nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS kernels)
endif()
